#ifndef KDTREE_H
#define KDTREE_H
#include "ICP.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <queue>
#include <random>
#include <stack>
#include <vector>

/*
class kdNode {
private:
	int axis_;
	Point3d point_;
	kdNode* left_;
	kdNode* right_;
public:	
    kdNode() {left_=right_=NULL;}
	kdNode(Point3d& point, int axis):point_(point),left_(NULL), right_(NULL), axis_(axis) {}
	~kdNode() {}
	int getAxis() {return axis_;}
	void setAxis(int axis) {axis_=axis;}
	Point3d getPoint() {return point_;}
	void setPoint(Point3d& point) {point_ = point;}
	kdNode* left() const {return left_;}
	void setLeft(kdNode* b) {left_ = b;}
	kdNode* right() const {return right_;}
	void setRight(kdNode* b) {right_ = b;}
	bool isLeaf() {return (NULL==left_) && (NULL==right_);}
	double getKval() {
		switch (axis_) {
			case(0): return point_.x;
			case(1): return point_.y;
			case(2): return point_.z;
	}
	
};

class kdTree {
private:
	vector<Point3d> E_;
	kdNode* root;
	
	static bool cmpx(Point3d a, Point3d b) {return a.x < b.x;}
	static bool cmpy(Point3d a, Point3d b) {return a.y < b.y;}
	static bool cmpz(Point3d a, Point3d b) {return a.z < b.z;}
	void sorted(int l, int r, int mode) {
	    switch (mode) {
	        case X:
	            sort(E_.begin()+l, E_.begin()+r+1, cmpx);
	            break;
	        case Y:
	            sort(E_.begin()+l, E_.begin()+r+1, cmpy);
	            break;
	        case Z:
	            sort(E_.begin()+l, E_.begin()+r+1, cmpz);
	            break;
	        default:
	            cout << "Error: The type is invalid!\n";
	    }
	}
	void build(int l, int r, kdNode* root, int axis) {
	    if (l == r) {
	    	root->setAxis(axis);
	    	root->setPoint(E_[l]); 
	    	root->setLeft(NULL); root->setRight(NULL); return;
	    }
	    sorted(l, r, axis%3);
	    int point_pos = (l+r) / 2;
	    Point3d point_med = E_[point_pos];
	    root->setPoint(point_med);
	    root->setAxis(axis);
	    axis = (axis+1) % 3;
	    build(l, point_pos-1, root->left(), axis);
	    build(point_pos+1, r, root->right(), axis);
	}
public:
	enum {X, Y, Z};

public:
	kdTree(vector<Point3d>& E) {
	    root = new kdNode();
	    E_ = E;
	    
	    build(0, E.size()-1, root, 0);
	}
	~kdTree() {}
	
	Fid
};
*/
// A single k-nearest-neighbour candidate: which point it is and how far it
// lies from the query point.
struct NearestNode {
	int node;          // point index (kdTree stores indices into its point array here)
	double distance;   // distance from the query point to that point

	// Both fields start at zero.
	NearestNode() : node(0), distance(0.0) {}

	NearestNode(int n, double d) : node(n), distance(d) {}
};

struct cmp {
	bool operator()(NearestNode a, NearestNode b) {return a.distance < b.distance;}
};

class kdTree {
public:
	enum {X, Y, Z};
private:
	int *treePtr;
	int **tree;
	int treeRoot;
	int treeSize;
	int kDimension;
	vector<Point3d> data;
	double getKval(const Point3d& p, int axis) {
		switch (axis) {
			case(X): return p.x;
			case(Y): return p.y;
			case(Z): return p.z;
			default: cout << "Error!!\n"; return 0;
		}
	}
	
	double computeDistance(const Point3d& p, int node) {
		return distance(p, data[node]);
	}
	
public:
	kdTree(vector<Point3d>& B) {
		data = B;
		bool suc = setSize(3, B.size());
		// if (!suc) cout << "Error!!\n";
		// for (int i=0; i<treeSize; i++) cout << data[i] << endl;
		buildTree();
	}
	~kdTree() {}
	
	bool setSize(int dim, unsigned int size) {
		kDimension = dim;
		treeSize = size;
		
		if (kDimension > 0 && treeSize > 0) {
			tree = new int *[4];
			treePtr = new int[4 * treeSize];
			
			for (int i=0; i<4; i++) tree[i] = treePtr + i * treeSize;
		}
		
		return true;
	}
	
	int buildTree() {
		vector<int> vtr(treeSize);
		for (int i=0; i<treeSize; i++) vtr[i] = i;
		
		std::random_shuffle(vtr.begin(), vtr.end());

		treeRoot = buildTree(&vtr[0], treeSize, -1);
		
		return treeRoot;
	}
	
	int chooseSplitDimension(int *ids, int sz, double &key) {
		int split = 0;
		double *var = new double[kDimension];
		double *mean = new double[kDimension];
		
		double rt = 1.0 / sz;
		
		for (int i=0; i< kDimension; i++) {
			double sum1 = 0, sum2 = 0;
			for (int j=0; j<sz; j++) {
				sum1 += rt * getKval(data[ids[j]], i) * getKval(data[ids[j]], i);
				sum2 += rt * getKval(data[ids[j]], i);
			}
			var[i] = sum1 - sum2 * sum2;
			mean[i] = sum2;
		}
		
		double max = 0;
		for (int i=0; i<kDimension; i++) {
			if (var[i] > max) {
				key = mean[i];
				max = var[i];
				split = i;
			}
		}
		
		return split;
		
	}
	
	int chooseMiddleNode(int *ids, int sz, int dim, double key) {
		int left = 0;
		int right = sz - 1;
		
		while (true) {
			while (left <= right && getKval(data[ids[left]], dim) <= key) ++left;
			while (left <= right && getKval(data[ids[right]], dim) >= key) --right;
			if (left > right) break;
			std::swap(ids[left], ids[right]);
			++left;
			--right;
		}
		
		double max = -9999999;
		int maxIndex = 0;
		for (int i=0; i<left; i++) {
			if (getKval(data[ids[i]], dim) > max) {
				max = getKval(data[ids[i]], dim);
				maxIndex = i;
			}
		}
		
		if (maxIndex != left - 1) std::swap(ids[maxIndex], ids[left-1]);
		return left - 1;
	}
	
	int buildTree(int *indices, int count, int parent) {
		if (count == 1) {
			int rd = indices[0];
			tree[0][rd] = 0;
			tree[1][rd] = parent;
			tree[2][rd] = -1;
			tree[3][rd] = -1;
			
			return rd;
		}
		
		else {
			double key = 0;
			int split = chooseSplitDimension(indices, count, key);
			int idx = chooseMiddleNode(indices, count, split, key);
			
			int rd = indices[idx];
			tree[0][rd] = split;
			tree[1][rd] = parent;
			
			if (idx > 0) tree[2][rd] = buildTree(indices, idx, rd);
			else tree[2][rd] = -1;
			if (idx+1 < count) tree[3][rd] = buildTree(indices+idx+1, count-idx-1, rd);
			else tree[3][rd] = -1;
			return rd;
		}
		
	}
	
	int findKNearests(const Point3d& p, int k, int *res) {
		std::priority_queue<NearestNode, std::vector<NearestNode>, cmp> kNeighbors;
		std::stack<int> paths;
		
		int node = treeRoot;
		while (node > -1) {
			paths.emplace(node);
			node = getKval(p, tree[0][node]) <= getKval(data[node], tree[0][node]) ? tree[2][node] : tree[3][node];
		}
		
		kNeighbors.emplace(-1, 9999999);
		
		double distance = 0;
		while (!paths.empty()) {
			node = paths.top();
			paths.pop();
			distance = computeDistance(p, node);
			if (kNeighbors.size() < k) kNeighbors.emplace(node, distance);
			else {
				if (distance < kNeighbors.top().distance) {
					kNeighbors.pop();
					kNeighbors.emplace(node, distance);
				}
			}
			
			if (tree[2][node] + tree[3][node] > -2) {
				int dim = tree[0][node];
				if (getKval(p, dim) > getKval(data[node], dim)) {
					if (getKval(p, dim)-getKval(data[node], dim) < kNeighbors.top().distance && tree[2][node] > -1) {
						int reNode = tree[2][node];
						while (reNode > -1) {
							paths.emplace(reNode);
							reNode = getKval(p, tree[0][reNode]) <= getKval(data[reNode], tree[0][reNode]) ? tree[2][reNode] : tree[3][reNode];
						}
					}
				}
				
				else {
					if (getKval(data[node], dim)-getKval(p, dim) < kNeighbors.top().distance && tree[3][node] > -1) {
						int reNode = tree[3][node];
						while (reNode > -1) {
							paths.emplace(reNode);
							reNode = getKval(p, tree[0][reNode]) <= getKval(data[reNode], tree[0][reNode]) ? tree[2][reNode] : tree[3][reNode];
						}
					} 
				}
			}
			
		}
		
		// if (!res) res = new int[k];
		
		int i = kNeighbors.size();
		while (!kNeighbors.empty()) {
			res[--i] = kNeighbors.top().node;
			kNeighbors.pop();
		}
		return 0;
	}
	
};





































#endif

